%%
% Process 'SpermTracks.mat' data and calculate basis modes.
% Variable species can be one of: 'Abalone', 'Arbacia', 'Bull', 'Ciona',
% 'Human-low', 'Human-high', 'Lpictus', 'Lvariegatus', 'Spurpuratus',
% 'Zebrafish'
%%


clear all; %#ok<*CLSCR>
close all;

% Root folder containing the per-experiment '<folder><k>_analysis'
% subfolders. fullfile keeps the path valid on every platform; combining
% filesep with hard-coded backslashes ('\data\OriginalData') only yields a
% valid path on Windows.
filepath = [fullfile(pwd, 'data', 'OriginalData') filesep];
species = 'Bull';

% Cross-validation layout for the basis-mode section: each of jjj repeats
% splits the pooled frames at random into iii groups.
switch species
    case {'Bull','Human-low','Human-high'}
        iii = 2; jjj = 25;
    otherwise
        iii = 5; jjj = 10;
end
chkrotation = false; % true: plot each group's modes before/after rotation

thresh = 0.0; % only stored with the saved results

% Fetch parameters stored in fetch_parameters.m
params = fetch_parameters(species);
%If running the velocity calc for eccentric, manually set eccentric to zero here:
%params(1).eccentric = 0;
file_nums = params(1).file_nums;
um2px = params(1).space_res;

sd = 30; % Standard deviation of Gaussian filter
fx = 200; % Number of timesteps to include in filter
filt = fspecial('gaussian', [fx, 1], sd);

eccentricity_col = 5; % Column with eccentricity info
run_num_col = 8; % Column with run number info from SpermTracks.mat
%rng(123) % Set random seed for repeatable results

% Convex-hull filter settings: each shape is combined with the N preceding
% shapes, and time steps whose mean-centred hull area is >= V_max are dropped.
N = 2;
V_max = 200;
tail_pts_to_remove = params(1).tail_pts_to_remove; % trailing points trimmed from each tail

% Load every experiment listed in file_nums. For each run, store the head
% track, the cleaned tail shapes, head speed and tail segment lengths in
% allruns, and the Gaussian-filtered head trajectory in procdata.
cnt = 1; % Initialize count
tot = 0; % running sum of head(end,4) over all loaded files
for k = file_nums
    % Import 'head' and 'tail' data. If the first folder layout fails, retry
    % with the last character of params(1).folder replaced by filesep.
    try
        load([filepath params(1).folder num2str(k) ...
            '_analysis' filesep 'SpermTracks.mat']);
    catch
        fallbackFile = [filepath params(1).folder(1:(end-1)) filesep num2str(k) ...
            '_analysis' filesep 'SpermTracks.mat'];
        load(fallbackFile);
        fprintf('Loaded fallback path: %s\n', fallbackFile);
    end
    % Take subset of columns for head tracked data:
    % columns 1:2 (filtered below as the x/y trajectory), eccentricity, run number
    head = head(:, [1:2 eccentricity_col run_num_col]);

    % Get unique run numbers
    runs = unique(head(:, 4));

    % Iterate through each run number and find time steps with tail data
    for i = 1:length(runs)

        ind0 = find(head(:, 4) == runs(i));
        if isempty(ind0)
            continue
        end

        % Clean each tail of bad intensity data (fixed-length tests). A point
        % is flagged if its 3rd column is > 0.001 or exactly 0. Only the
        % largest contiguous stretch of flagged points, plus every point
        % before it, is kept. Tails left with fewer than 10 points are emptied.
        for j = 1:length(ind0)
            try
                tailFlag = (tail{ind0(j), 1}(:,3) > 0.001 | tail{ind0(j), 1}(:,3)==0);
                if any(tailFlag)
                    % bwconncomp runs on a 3-column copy of the flags; the
                    % middle column is read back afterwards.
                    tailFlag = [tailFlag,tailFlag,tailFlag];

                    CC = bwconncomp(tailFlag);
                    tailFlag(:)=0;
                    numPixels = cellfun(@numel,CC.PixelIdxList);
                    [~,nn] = max(numPixels);
                    tailFlag(CC.PixelIdxList{nn}) = 1;
                    tailFlag = tailFlag(:,2);
                    nn = find(tailFlag,1,'first');
                    tailFlag(1:nn) = 1;

                    if sum(tailFlag)<10
                        tail{ind0(j), 1} = [];
                    else
                        tail{ind0(j), 1} = tail{ind0(j), 1}(tailFlag,:);
                    end
                end
            catch
                % Deliberate best effort: a tail that cannot be cleaned
                % (e.g. no 3rd column) is left exactly as loaded.
            end
        end

        if params(1).eccentric == 1
            % If tracking is done by eccentricity, keep only time steps at
            % local minima of the stored eccentricity (peaks of -ecc) that
            % are deeper than the average minimum, excluding ecc == 0.
            allruns(cnt).ecc = real(head(ind0, 3));

            mask = find(allruns(cnt).ecc == 0);

            [pk, lc] = findpeaks(-allruns(cnt).ecc);
            %Changed to be > from < on 06/17/2020
            lc = lc(pk > mean(pk));

            ind0 = ind0(lc(~ismember(lc, mask)));
            %ind0 = ind0(allruns(cnt).ecc<-mean(pk));

            if isempty(ind0)
                continue
            end

        end

        % Save head and tail points (empty tail entries are dropped)
        allruns(cnt).head = head(ind0, :);
        tail_run = tail(ind0);
        allruns(cnt).tail = tail_run(~cellfun('isempty', tail_run));

        % Filter trajectory of head track with the Gaussian kernel filt
        Ux = imfilter(allruns(cnt).head(:, 1), filt, 'same', 'replicate');
        Uy = imfilter(allruns(cnt).head(:, 2), filt, 'same', 'replicate');

        % Differential trajectory from neutral axis
        dUx = Ux - allruns(cnt).head(:, 1);
        dUy = Uy - allruns(cnt).head(:, 2);

        % Step speed of the filtered trajectory scaled by time_res and
        % um2px: median over the run (vel) and per step (newVel).
        allruns(cnt).vel = nanmedian(sqrt(diff(Ux).^2 + diff(Uy).^2)) * ...
            params(1).time_res * um2px;
        allruns(cnt).newVel = sqrt(diff(Ux).^2 + diff(Uy).^2) *params(1).time_res * um2px;

        % Length of every segment between consecutive tail points
        for m = 1:size(allruns(cnt).tail,1)
            inc_length = diff(allruns(cnt).tail{m, 1}(:,1:2));
            allruns(cnt).length{m, 1} = sqrt(sum(inc_length.^2,2));
        end

        procdata(cnt).head_filt(:, 1) = Ux;
        procdata(cnt).head_filt(:, 2) = Uy;

        cnt = cnt + 1;

    end
    tot = tot+head(end,4);
end

% Find all runs with tail data (at least one tail shape kept).
n_allruns = find(arrayfun(@(run) size(run.tail, 1) > 0, allruns));

% Variable-length data effectively bypass the convex-hull filter.
% This must assign V_max (the name read by that filter), not Vmax.
if (params(1).variable_length == 1)
    V_max = 1e5;
end

%% Process data: reject erratic flagellum shapes with a convex-hull test.
% For each time step, the convex hull of the current shape together with
% the N preceding shapes is computed. Each shape is first translated so
% that its first point is at the origin. Erratic shapes make this hull
% large, so time steps whose mean-centred hull area is >= V_max are dropped.
% Skipped when tracking by eccentricity or when species is 'Bull'.
if (params(1).eccentric == 0) && (~strcmp(species, 'Bull'))
    
    for i = n_allruns
        
        n_timesteps = length(allruns(i).tail);
        % Points per shape after trimming, taken from the first shape.
        len = length(allruns(i).tail{1}(1:end-tail_pts_to_remove, 1)); % length of tail
        % Placeholder only: every entry is overwritten in the loop below.
        V = 1e6 * ones(n_timesteps, 1); % initialize convex hull volume
        
        % Initialize the window with N copies of the first tail, so the
        % first time steps also have N predecessors.
        tailx = repmat(allruns(i).tail{1}(1:end-tail_pts_to_remove, 1) - ...
            allruns(i).tail{1}(1, 1), N, 1);
        taily = repmat(allruns(i).tail{1}(1:end-tail_pts_to_remove, 2) - ...
            allruns(i).tail{1}(1, 2), N, 1);
        % Number of points contributed by each shape currently in tailx/taily.
        tail_lengths = len * ones(N, 1);
        
        for k = 1 : n_timesteps
            
            % Append the current tail (translated to start at the origin).
            tailx = [tailx; ...
                allruns(i).tail{k}(1:end-tail_pts_to_remove, 1) - ...
                allruns(i).tail{k}(1, 1)]; %#ok<*AGROW>
            taily = [taily; ...
                allruns(i).tail{k}(1:end-tail_pts_to_remove, 2) - ...
                allruns(i).tail{k}(1, 2)];
            tail_length0 = length(allruns(i).tail{k}(1:end-tail_pts_to_remove, 1));
            tail_lengths = [tail_lengths; tail_length0];
            
            % Compute convex hull; for 2-D input the second output is the area.
            [K, V(k)] = convhull(tailx, taily);
            
            % Drop the oldest tail so the window again holds N shapes.
            tailx(1:tail_lengths(1)) = [];
            taily(1:tail_lengths(1)) = [];
            tail_lengths(1) = [];
            
        end
        
        % Centre V on the mean area of the steps with V < 1e5.
        % NOTE(review): if no step has V < 1e5 the mean is NaN, every
        % comparison below is false, and all shapes of this run are dropped.
        V = V - mean(V(V < 1e5));
        
        % Keep only time steps with small enough convex hull volume
        allruns(i).tail = allruns(i).tail(V < V_max);
        
        clear tailx taily V
    end
end

% Number of samples per flagellum for the curvature grid orig_vec.
% Fixed-length data use 30 points. Otherwise the length of the first
% trimmed shape of the first run that has tail data is used; allruns(1)
% itself may have none. For variable_length == 1, orig_vec is reset to a
% 30-point grid inside the loop anyway.
tail_length = 30;
if params(1).variable_length ~= 0 && ~isempty(n_allruns) && ~isempty(allruns(n_allruns(1)).tail)
    tail_length = length(allruns(n_allruns(1)).tail{1}(1 : end - tail_pts_to_remove, 1));
end

orig_vec = linspace(0, 1, tail_length);

A_allruns = [];

for i = n_allruns %runs for each experiment with both head and tail data

    % If variable length, resample each flagellum at 30 points equally
    % spaced in normalised arc length.
    if params(1).variable_length == 1

        s0 = linspace(0, 1, 30);
        orig_vec = s0;

        for j = 1 : length(allruns(i).tail)
            x = allruns(i).tail{j}(1:end-tail_pts_to_remove, 1);
            y = allruns(i).tail{j}(1:end-tail_pts_to_remove, 2);

            % Cumulative arc length is the running sum of segment LENGTHS
            % (sqrt of squared increments), normalised to [0, 1].
            s = [0; cumsum(sqrt(diff(x).^2 + diff(y).^2))];
            s = s / max(s);

            allruns(i).tail_final{j}(:, 1) = interp1(s, x, s0);
            allruns(i).tail_final{j}(:, 2) = interp1(s, y, s0);
        end

    else

        allruns(i).tail_final = allruns(i).tail;

    end

    % Curvature of each Savitzky-Golay smoothed flagellum, one column per
    % time point. A starts empty so that a run whose tails were all rejected
    % contributes nothing instead of failing with "Undefined A".
    A = [];
    for j = 1 : length(allruns(i).tail_final)
        smoothx = smooth(allruns(i).tail_final{j}(:,1),'sgolay');
        smoothy = smooth(allruns(i).tail_final{j}(:,2),'sgolay');
        A(:, j) = calculate_curvature([smoothx smoothy]*um2px, orig_vec);
        allruns(i).curvature{j} = A(:,j);
    end

    A_allruns = [A_allruns A];
    clear A;

end

%% Calculate basis modes
% For every track with tail data, this loop does four things:
%   1. an SVD of its curvature matrix (one column per time step);
%   2. quality checks on the first two modes;
%   3. a rotation of the mode-1/mode-2 plane by the angle of an ellipse
%      fitted to the mode coefficients;
%   4. a quality check on the 2-mode reconstruction residual.

N_min_samples = 100;

for nidx=n_allruns

    % A track whose tails were all rejected has no curvature data. Reject
    % it, but keep a NaN row in S1 so S1 rows stay aligned with n_allruns
    % (all-zero rows are deleted later; the nan-aware statistics skip it).
    if isempty(allruns(nidx).curvature)
        gooddata(nidx) = false;
        S1(nidx, 1:3) = NaN;
        continue
    end
    Acurv = cell2mat(allruns(nidx).curvature);

    [U, S, V] = svd(Acurv);

    sv = diag(S);

    procdata(nidx).U = U;
    procdata(nidx).S = cumsum(sv.^2) ./ sum(sv.^2); % cumulative energy fraction

    % Reject the track if either of the first two (unit-norm) modes has
    % any entry whose magnitude exceeds a species-specific cap.
    switch species
        case 'Human-high'
            modeCap = 0.6;
        case 'Human-low'
            modeCap = 0.5;
        otherwise
            modeCap = 0.4;
    end
    gooddata(nidx) = ~any(abs(U(:, 1)) > modeCap | abs(U(:, 2)) > modeCap);

    % Cumulative energy of modes 1-3. With fewer than 3 singular values all
    % energy is already captured, so the missing entries are padded with 1.
    cumEnergy3 = procdata(nidx).S(:);
    cumEnergy3(end+1:3) = 1;
    S1(nidx,:) = cumEnergy3(1:3);

    procdata(nidx).A = Acurv;
    procdata(nidx).stroke = U\procdata(nidx).A; % mode coefficients (modes x time steps)

    % Reconstructions from the first two modes and from all modes.
    term_2 = U(:, 1:2) * procdata(nidx).stroke(1:2, :);
    term_all = U * procdata(nidx).stroke;

    % NOTE(review): stroke is (modes x time steps), so stroke(:,1) and
    % stroke(:,2) are the coefficients of time steps 1 and 2, not the
    % mode-1/mode-2 series (rows 1 and 2) used elsewhere. This looks
    % unintended, but the residual cut-offs below were presumably tuned
    % with this normalisation, so it is left unchanged -- confirm.
    strokemag = sqrt(max(procdata(nidx).stroke(:,1))^2+max(procdata(nidx).stroke(:,2))^2);

    % Rotation: fit an ellipse (fit_ellipse) to the mode-1/mode-2
    % coefficients with radial outliers removed, then rotate the first two
    % modes and their coefficients by the fitted angle ED.phi.
    X = [procdata(nidx).stroke(1,:); procdata(nidx).stroke(2,:);]';
    Xd = sqrt(procdata(nidx).stroke(1,:).^2+procdata(nidx).stroke(2,:).^2);
    isOX = isoutlier(Xd);
    G = ~isOX;
    idx = G==1;
    ED = fit_ellipse(X(idx,1),X(idx,2));

    % Fall back to no rotation when the fit provides no usable angle.
    try
        ang = 180*-ED.phi/pi;
    catch
        ang=0;
    end
    if isempty(ang), ang = 0; end

    % Species exception to rotation: Bull tracks are not rotated.
    if strcmp(species, 'Bull')
        ang = 0;
    end

    Rinv = [cosd(ang) -sind(ang); sind(ang) cosd(ang)];
    procdata(nidx).Urot = procdata(nidx).U(:,1:2)*Rinv;
    procdata(nidx).strokerot = Rinv'*procdata(nidx).stroke(1:2,:);
    % NOTE(review): the rotated coefficients are not re-centred. If centring
    % on the fitted ellipse is wanted, subtract [ED.X0; ED.Y0] here.

    procdata(nidx).strokeold =  procdata(nidx).stroke;
    procdata(nidx).stroke =  procdata(nidx).strokerot;

    % Residuals of the full and the 2-mode reconstructions. The squared
    % error is summed per image, divided by 30 and strokemag, and the
    % median is taken over the track.
    resid_all = (procdata(nidx).A-term_all).^2;
    resid_all_perimg = sum(resid_all,1)/30/strokemag;
    resid_all_medianpertrack = median(resid_all_perimg);

    resid_2terms = (procdata(nidx).A-term_2).^2;
    resid_2terms_perimg = sum(resid_2terms,1)/30/strokemag;
    resid_2terms_medianpertrack = median(resid_2terms_perimg);

    procdata(nidx).A_2terms = term_2;

    % The *_meanpertrack fields hold MEDIANS; their names are unchanged so
    % existing consumers of procdata keep working.
    procdata(nidx).residuals = struct();
    procdata(nidx).residuals.resid_all = resid_all;
    procdata(nidx).residuals.resid_all_meanpertrack = resid_all_medianpertrack;
    procdata(nidx).residuals.resid_2terms = resid_2terms;
    procdata(nidx).residuals.resid_2terms_perimg  = resid_2terms_perimg;
    procdata(nidx).residuals.resid_2terms_meanpertrack = resid_2terms_medianpertrack;

    all2termresids(nidx) = resid_2terms_medianpertrack;

    % Reject tracks whose median 2-mode residual is too large; the cut-off
    % is looser for Bull and human data.
    if any(strcmp(species, {'Bull','Human-low','Human-high'}))
        residCap = 0.4;
    else
        residCap = 0.2;
    end
    if resid_2terms_medianpertrack > residCap
        gooddata(nidx) = false;
    end

    if gooddata(nidx)
        figure(nidx)
        subplot(1,2,1); plot_modes(U, species)
        ylim([-0.4 0.4]);
        subplot(1,2,2); plot_modes(procdata(nidx).Urot,species);
        ylim([-0.4 0.4]);


        figure(775); subplot(1,2,1); scatter(procdata(nidx).strokeold(1,:),procdata(nidx).strokeold(2,:),'MarkerEdgeColor','r'); axis image;
        subplot(1,2,2); scatter(procdata(nidx).strokerot(1,:),procdata(nidx).strokerot(2,:),'MarkerEdgeColor','b'); axis image;

        disp(['Track ' num2str(nidx) '/' num2str(max(n_allruns)) ', Tr.Len: ' num2str(size(procdata(nidx).A, 2)) ' Residual for 2 terms: ' num2str(resid_2terms_medianpertrack)]);

    end


end

close all;
% Table S1 statistics: cumulative energy captured by modes 1-3 per track.
% S1 rows for indices outside n_allruns were never written and are all
% zero; deleting them leaves one row per entry of n_allruns.
S1(sum(S1, 2) == 0, :) = [];
cumEnergy = nanmean(S1,1);
cumStdev = nanstd(S1,[],1);
disp(['S1.Mean cumulative 2nd mode energy/track (of all tracks): ' num2str(cumEnergy(2)) ' pm ' num2str(cumStdev(2))])

% Accepted tracks: members of n_allruns whose quality flag is still true.
goodidx = find(gooddata');
tf = ismember(n_allruns, goodidx);
new_n_allruns = n_allruns(tf);

disp(['S1.Number of included sperm tracks: ' num2str(sum(tf)) '/' num2str(length(tf))])
OTcumEnergy = nanmean(S1(tf,:),1);
OTcumStdev = nanstd(S1(tf,:),[],1);
disp(['S1.Mean cumulative 2nd mode energy/track (of overthresh): ' num2str(OTcumEnergy(2)) ' pm ' num2str(OTcumStdev(2))])

% Rotated mode-1/mode-2 coefficients of all accepted tracks, side by side.
strokeRot = [procdata(new_n_allruns).stroke];


%%%%Run Make Table S1 here if desired%%%%
%%
% Pool the curvature matrices of all accepted tracks column-wise into
% A_overthresh. track_frames(c) is the number of time steps of the c-th
% accepted track, so sum_ls(c) is that track's last column in A_overthresh.
% (Named track_frames so the built-in ls stays usable.)
track_frames = zeros(1, numel(new_n_allruns));
for cnt = 1:numel(new_n_allruns)
    nidx = new_n_allruns(cnt);
    procdata(nidx).A_overthresh = procdata(nidx).A;
    track_frames(cnt) = size(procdata(nidx).A, 2);
end
A_overthresh = [procdata(new_n_allruns).A];
sum_ls = cumsum(track_frames);

% Cross-validated basis modes. Each of jjj repeats splits the pooled frames
% (columns of A_overthresh) at random into iii groups and takes one SVD
% per group.
stroke1all = [];
stroke2all = [];
figure(200)
for j =1:jjj
    % Random group label (1..iii) for every pooled frame.
    idx = ceil(iii*rand(size(A_overthresh,2),1));

    for i = 1:iii
        fold = i + iii*(j-1);            % linear index into U_temp / S_temp
        A_temp = A_overthresh(:,idx==i);

        [Ufold, S_temp{fold}, ~] = svd(single(A_temp));

        % Orient modes 1 and 2 so their first three samples increase on
        % average, giving every group the same sign convention.
        for k =1:2
            difft = mean(diff(Ufold(1:3,k)));
            if difft <0
                Ufold(:,k) = -Ufold(:,k);
            end
        end

        strokebrk = Ufold\A_temp;        % coefficients of every mode

        stroke1{i,j} = strokebrk(1,:);
        stroke2{i,j} = strokebrk(2,:);

        [rotU_temp,ang] = rotateU(Ufold,strokebrk,A_temp,species);

        % Store the group's modes with the first two replaced by the rotated pair.
        U_tempOrig = Ufold;
        Ufold(:,1:2) = rotU_temp;
        U_temp{fold} = Ufold;

        stroke1all = [stroke1all, stroke1{i,j}];
        stroke2all = [stroke2all, stroke2{i,j}];

        if chkrotation
            figure(120);
            plot_modes(U_temp{fold}, species);
            hold on;

            strkOrig = U_tempOrig\A_temp;
            strkRot = rotU_temp\A_temp;

            figure(776); subplot(1,2,1); scatter(strkOrig(1,:),strkOrig(2,:),'MarkerEdgeColor','r'); axis image;
            %hold on;
            subplot(1,2,2); scatter(strkRot(1,:),strkRot(2,:),'MarkerEdgeColor','b'); axis image;
        end
    end
end


% Pool the mode-1/mode-2 coefficients of all groups, flag outliers in their
% distance from the origin, and fit an ellipse to the inliers.
X = [stroke1all', stroke2all';];
Xd = sqrt(stroke1all.^2+stroke2all.^2);
isOX = isoutlier(Xd);
G = ~isOX;


h = figure(12302);
ellipsedata = fit_ellipse(stroke1all(G),stroke2all(G),h); hold on;
res=10;
ellipsedata.xy = double(ellipsedata.xy);
% Rasterise the fitted ellipse on a [-60, 60] x [-60, 60] grid with res
% pixels per unit and show it beneath the stroke scatter.
ellipseBW = poly2mask((ellipsedata.xy(1,:)+60)*res+1,(ellipsedata.xy(2,:)+60)*res+1,120*res+1,120*res+1);
imagesc(-60:(1/res):60,-60:(1/res):60,ellipseBW)
plot(stroke1all,stroke2all,' .');

scmax = 1.5; scmin = 0.01;
sc = 1; %1.15 for 97.5%
uTh = 0.95; lTh = 0.05; %currently 90% confidence interval
in_bestfit = inpolygon(stroke1all,stroke2all,ellipsedata.xy(1,:)*sc,ellipsedata.xy(2,:)*sc);
ratio_bestfit = sum(in_bestfit)/length(in_bestfit);

% Bisect for the scale factors at which the ellipse, scaled about the
% origin, contains fractions uTh and lTh of all strokes. Each search runs
% at most 20 steps and stops within 0.001 of its target.
% ell_sc(2) receives the upper bound and ell_sc(1) the lower bound.
% Columns: bracket low, bracket high, target fraction, slot in ell_sc.
bisectCfg = [sc    scmax uTh 2;
             scmin sc    lTh 1];
for b = 1:size(bisectCfg, 1)
    s_ = bisectCfg(b, 1:2); target = bisectCfg(b, 3);
    ratio = 0; count = 0;
    while abs(ratio-target)>0.001 && count<20
        midScale = mean(s_);
        in_temp = inpolygon(stroke1all,stroke2all, ellipsedata.xy(1,:)*midScale,ellipsedata.xy(2,:)*midScale);
        ratio = sum(in_temp)/length(in_temp);
        ell_sc(bisectCfg(b, 4)) = midScale;
        if ratio < target
            s_(1) = midScale;
        else
            s_(2) = midScale;
        end
        count = count+1;
    end
end

% 2-D histogram of the inlier strokes with the best-fit ellipse (solid)
% and its lower/upper scaled bounds (dashed).
figure(12303)
n = ndhist(stroke1all(G),stroke2all(G),'int',1,'bins',5,'axis',[-60 60 -60 60]); axis image; colorbar;
hold on;
title([species ', Mode Plot']);
xlabel('Mode 1');
ylabel('Mode 2');
plot(ellipsedata.xy(1,:)*ell_sc(1),ellipsedata.xy(2,:)*ell_sc(1),'--','Color','magenta','LineWidth',2);
plot(ellipsedata.xy(1,:)*ell_sc(2),ellipsedata.xy(2,:)*ell_sc(2),'--','Color','magenta','LineWidth',2);
plot(ellipsedata.xy(1,:),ellipsedata.xy(2,:),'-','Color','magenta','LineWidth',2);
colormap magma;

%%
% Scatter of the pooled group strokes, coloured by inlier flag G, with the
% 50% covariance ellipse of the inliers.
figure(9001);

gscatter(X(:,1), X(:,2), G)
%scatter(X(:,1), X(:,2))
axis equal, hold on

% indices of points in the inlier group
idx = ( G == 1 );

% subtract mean
Mu = mean( X(idx,:) );
X0 = bsxfun(@minus, X(idx,:), Mu);

% For a 2-D Gaussian the squared Mahalanobis radius is chi-square with
% 2 dof, so scaling the covariance by chi2inv(conf,2) gives the ellipse
% that encloses a fraction conf of the population.
conf = 0.5;                  % 50% ellipse
scale = chi2inv(conf,2);

Cov = cov(X0) * scale;
[V, D] = eig(Cov);


t = linspace(0,2*pi,100);
e = [cos(t) ; sin(t)];        % unit circle
VV = V*sqrt(D);               % scale eigenvectors
e = bsxfun(@plus, VV*e, Mu'); % project circle back to orig space

plot(e(1,:), e(2,:), 'Color','k');


%%
% Median and spread of modes 1 and 2 over all cross-validation groups.
% cell2mat(U_temp) places the group mode matrices side by side, each
% assumed square (dm x dm, as from a full svd of a dm-row matrix). Group f's
% mode 1 is therefore in column 1 + dm*(f-1) and its mode 2 in the next column.
dm = size(U_temp{1},1);
idx = 1 + dm*(0:numel(U_temp)-1);
allGroupModes = cell2mat(U_temp);
Mode1 = allGroupModes(:, idx);
Mode2 = allGroupModes(:, idx + 1);

% Spread is the standard deviation of |mode| across groups.
figure(201)
std1 = std(abs(Mode1),[],2);
shadedErrorBar(1:length(std1),median(Mode1,2),std1,'lineprops','-o');%{'r-o','markerfacecolor','r'});
hold on;
title('Mode1')

figure(202)
std2 = std(abs(Mode2),[],2);
shadedErrorBar(1:length(std2),median(Mode2,2),std2,'lineprops','-o');%{'r-o','markerfacecolor','r'});
hold on;
title('Mode2')




%%
% Optionally save the results. kfolds records the cross-validation layout
% used above: [number of repeats, groups per repeat].
kfolds = [jjj iii];

saving = false;
date_tag = '200617A'; % file-name suffix (not named 'date', which would shadow the built-in date())
if saving
    % NOTE(review): at this point 'sv' still holds the singular values of
    % the last track handled by the per-track loop; the pooled SVD that
    % reassigns sv runs after this block. Confirm which one should be saved.
    try
        save([species '_' date_tag '.mat'],'A_overthresh', 'procdata', 'allruns','U_temp','thresh','kfolds','sv','S_temp')
    catch
        % Fallback: save without procdata, allruns and sv.
        save([species '_' date_tag '.mat'], 'A_overthresh', 'U_temp','thresh','kfolds','S_temp')
    end
end

% SVD of all accepted frames pooled together.
[U, S, V] = svd(single(A_overthresh));
%U_temp = U;

%[U, S, V] = svd(single(A_allruns));
sv = diag(S);

stroke = U\A_overthresh;

% Overlay the pooled modes on the current figure. When the script runs top
% to bottom, this is presumably the Mode2 plot (figure 202).
hold on;
plot_modes(U, species)

% 2-D histogram of the rotated mode-1/mode-2 coefficients of the accepted tracks.
figure(900)
%ndhist(stroke(1,:),stroke(2,:),'int',1,'bins',5); axis image; colorbar;
ndhist(strokeRot(1,:),strokeRot(2,:),'int',1,'bins',5); axis image; colorbar;

title([species ', Mode Plot']);
xlabel('Mode 1');
ylabel('Mode 2');

% 3-D trajectory of every accepted track in the space of the first three
% pooled modes. With edges = [0, sum_ls], track p occupies columns
% edges(p)+1 : edges(p+1) of A_overthresh, so every track is plotted once,
% starting at its own first column.
figure(901)
trackEdges = [0, sum_ls];
for p = 1:(numel(trackEdges)-1)
    cols = (trackEdges(p)+1):trackEdges(p+1);
    stroke = U\A_overthresh(:,cols);
    plot3(stroke(1,:),stroke(2,:),stroke(3,:),' .');
    hold on;
end

